#!/usr/bin/env python3
'''
Author: Paschalis Agapitos
Project: Mestizajes

Loads a parquet dataset containing Wikipedia notability features, normalizes
and lemmatizes those features, groups semantically similar occupations with
sentence embeddings, and assigns broader fields of human activity through
similarity-based categorization. The processed results are then written
to a separate output parquet file for downstream analysis.
'''

import logging
import os
import random
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import ast
import nltk
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from nltk.stem import WordNetLemmatizer
from sentence_transformers import SentenceTransformer
from sklearn.cluster import AgglomerativeClustering

# keep third-party model-loading chatter out of notebook output:
# transformers only reports errors, sentence_transformers warnings and above
for _noisy_logger, _min_level in (
    ("transformers", logging.ERROR),
    ("sentence_transformers", logging.WARNING),
):
    logging.getLogger(_noisy_logger).setLevel(_min_level)
del _noisy_logger, _min_level

# NOTE(review): judging by its name, this flag selects a fast GPU loading path
# in safetensors rather than suppressing logs — confirm it is still needed.
os.environ["SAFETENSORS_FAST_GPU"] = "1"

# default configuration values intended to be easy to override from a notebook
DEFAULT_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
DEFAULT_CHUNKSIZE = 1024
# cosine-similarity threshold for clustering features (the clustering distance
# threshold is 1 - this value); higher values merge fewer features together,
# i.e. fewer false-positive merges
DEFAULT_SIMILARITY_THRESHOLD = 0.80
DEFAULT_BATCH_SIZE = 128
DEFAULT_SEED = 42

# candidate "field of human activity" labels; they are lowercased before being
# embedded. Labels must not contain commas: several labels per row are later
# joined with "," in the output column.
DEFAULT_RAW_CATEGORIES: List[str] = [
    "Science", "Research", "Technology", "Physicist", "Aerospace", "Engineering", "Mathematics",
    "Scholar", "Education", "Philosophy", "Literature", "Poetry", "Translator", "Critic",
    "Art", "Artisan", "Potter", "Sport", "Athlete", "Racer", "Fencer", "Baseman", "Curler",
    "Sailor", "Business", "Economics", "Entrepreneurship", "Philanthropist", "Merchant",
    "Farmer", "Agriculture", "Religious", "Mysticism", "Spirituality", "Military", "Photographer",
    "Law Enforcement", "Advocate", "Royal", "Monarch", "Ruler", "Politics", "Government",
    "Governor", "Aristocrat", "Count", "Fashion", "Model", "Journalism", "Mayor", "Media",
    "Radio Host", "Theater", "Cinema", "Comedian", "Music", "Healthcare", "Medical", "Canoeist",
    "Social Worker", "Editor", "Architect", "Contractor", "Designer", "Navigator", "Executive",
    "Explorer", "Diarist", "Jurist", "Botanist", "Chess grandmaster", "Leader", "Geographer",
    "Manager", "Admiral", "Eques", "Sophist", "Amora", "Strategist", "General", "Archaeologist",
    "Statesman", "Augusta", "Defenceman", "Actor", "Actress", "Conductor", "Psychotherapist",
    "punt returner", "YouTuber", "Streamer", "Influencer", "Medium Personality", "Media Personality",
    "Grappler", "Luta Livre", "sprint canoeist", "slalom canoeist", "kayaker", "rower", "return specialist"
]

# keyword overrides: a feature containing one of these keywords is forced into
# the mapped category instead of the one chosen by embedding similarity.
# Every keyword currently maps to the same label; keys are checked in order.
# NOTE(review): "Sports" is not lowercased, while embedding-based labels come
# from the lowercased category list (e.g. "sport") — confirm downstream
# analysis expects these two spellings to differ.
# NOTE(review): "sprint canoeist" / "slalom canoeist" are already covered by
# the shorter "canoeist" keyword, which is checked first.
KEYWORD_CATEGORY_OVERRIDES: Dict[str, str] = dict.fromkeys(
    (
        "canoeist",
        "kayaker",
        "rower",
        "luta livre",
        "grappler",
        "sprint canoeist",
        "slalom canoeist",
        "return specialist",
    ),
    "Sports",
)


def apply_keyword_override(
    feature: str,
    overrides: Optional[Dict[str, str]] = None,
) -> Optional[str]:
    """
    Check if a feature string contains any keyword that should force a category.

    Keywords are matched case-insensitively as whole words, optionally
    followed by a plural "s". For example, "rower" matches "olympic rower"
    and "rowers", but not "tobacco grower" or "borrower". Multi-word keywords
    such as "luta livre" are matched as a whole phrase.

    Parameters
    ----------
    feature:
        Feature text, typically a lowercased, lemmatized notability feature.
    overrides:
        Mapping from keyword to forced category. When None, the module-level
        KEYWORD_CATEGORY_OVERRIDES is used. It is looked up at call time, so
        edits made to it from a notebook are honoured.

    Returns
    -------
    str or None
        The category of the first matching keyword (in mapping order), or
        None if no keyword matches.
    """
    if overrides is None:
        overrides = KEYWORD_CATEGORY_OVERRIDES

    feature_lower = feature.lower()
    for keyword, category in overrides.items():
        # the (?<!\w) / (?!\w) guards act as word boundaries that also work for
        # keywords containing spaces; plain substring search would let
        # "rower" hit "grower"
        pattern = rf"(?<!\w){re.escape(keyword.lower())}s?(?!\w)"
        if re.search(pattern, feature_lower):
            return category
    return None


def setup_nltk() -> None:
    """
    Fetch the NLTK resources needed by the WordNet lemmatizer.

    Calls nltk.download quietly for "wordnet" and "omw-1.4" every time this
    function runs.
    """
    for resource in ("wordnet", "omw-1.4"):
        nltk.download(resource, quiet=True)


def set_seed(seed: int = DEFAULT_SEED) -> None:
    """
    Seed every random number generator the pipeline relies on.

    Python's ``random`` module, NumPy's global generator and torch all receive
    the same seed. When CUDA is available, all CUDA devices are seeded too.

    Parameters
    ----------
    seed:
        Seed value shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # only touch CUDA generators when a GPU is actually present
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# text processing
# one shared WordNet lemmatizer; the wordnet corpus it needs is downloaded by
# setup_nltk()
lemmatizer = WordNetLemmatizer()


def lemmatize_text(text: str) -> str:
    """
    Normalize free text into space-separated lemmas.

    The text is lowercased and split into runs of word characters. Each run
    is lemmatized with the shared WordNet lemmatizer, and the results are
    joined with single spaces.

    Parameters
    ----------
    text:
        Input text.

    Returns
    -------
    str
        Lemmatized tokens joined by spaces; an empty string when the text
        contains no word characters.
    """
    words = re.findall(r"\w+", text.lower())
    lemmas = [lemmatizer.lemmatize(word) for word in words]
    return " ".join(lemmas)


def _is_missing_scalar(value: Any) -> bool:
    """Return True if *value* is a scalar missing marker (None, NaN, pd.NA, NaT)."""
    # containers are never "missing" here; pd.isna would return an array
    if isinstance(value, (str, bytes, list, tuple, dict, set, np.ndarray)):
        return False
    try:
        missing = pd.isna(value)
    except (TypeError, ValueError):
        return False
    return isinstance(missing, (bool, np.bool_)) and bool(missing)


def _coerce_to_feature_list(feats: Any) -> List[Any]:
    """Turn one raw cell of the notability column into a list of raw items."""
    # already an iterable of items (list / tuple / ndarray)
    if isinstance(feats, (list, tuple, np.ndarray)):
        return list(feats)

    # None, NaN, pd.NA, NaT, ...
    if _is_missing_scalar(feats):
        return []

    # strings, possibly containing a repr of a python list
    if isinstance(feats, str):
        try:
            parsed = ast.literal_eval(feats)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # not a python literal: treat it as a comma-separated list,
            # tolerating stray brackets and quotes
            cleaned_str = feats.strip().lstrip("[").rstrip("]")
            return [
                part.strip("'\" ").strip()
                for part in cleaned_str.split(",")
                if part.strip()
            ]
        if isinstance(parsed, (list, tuple)):
            return list(parsed)
        if isinstance(parsed, str):
            # a quoted single feature such as "'physicist'"
            return [parsed]
        return []

    # any other type: best-effort iteration, otherwise nothing
    try:
        return list(feats)
    except Exception:
        return []


def parse_and_trim_notability(
    df: pd.DataFrame,
    column: str = "notability_features",
    max_features: Optional[int] = 2,
) -> pd.DataFrame:
    """
    Parse and normalize the notability feature column into lists of lemmatized strings.

    The function is robust to several input formats:
    - None or NaN
    - python lists / tuples / numpy arrays of strings
    - stringified python lists (e.g. "['physicist', 'mathematician']")
    - a quoted single string (e.g. "'physicist'")
    - simple comma-separated strings

    Within each cell, missing items (None / NaN) and items that lemmatize to
    an empty string are skipped. Of the remaining items, only the first
    ``max_features`` are kept, in their original order.

    Parameters
    ----------
    df:
        Input dataframe containing a notability feature column.
    column:
        Name of the column that holds notability features.
    max_features:
        Maximum number of usable features kept per row (default 2). Pass
        None to keep all of them. Expected to be non-negative.

    Returns
    -------
    pd.DataFrame
        A copy of ``df`` with the specified column replaced by lists of
        lemmatized strings. If the column is absent, ``df`` itself is
        returned unchanged.
    """
    if column not in df.columns:
        return df

    def clean_feats(feats: Any) -> List[str]:
        cleaned: List[str] = []
        for item in _coerce_to_feature_list(feats):
            if max_features is not None and len(cleaned) >= max_features:
                break
            # without this, str(None) would become the feature "none"
            if _is_missing_scalar(item):
                continue
            lemma = lemmatize_text(str(item))
            # items made only of punctuation lemmatize to ""; drop them
            if lemma:
                cleaned.append(lemma)
        return cleaned

    df = df.copy()
    df[column] = df[column].apply(clean_feats)
    return df


def deduplicate_occupations_with_similarity(
    occupations: List[str],
    model: SentenceTransformer,
    device: str,
    similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
    batch_size: int = DEFAULT_BATCH_SIZE
) -> Tuple[Dict[int, List[str]], Dict[int, torch.Tensor]]:
    """
    Cluster occupation strings using cosine similarity of sentence embeddings.

    Occupations are embedded with ``model`` and L2-normalized. They are then
    grouped by average-linkage agglomerative clustering on cosine distance
    (1 - cosine similarity), with a distance threshold of
    ``1 - similarity_threshold``.

    Parameters
    ----------
    occupations:
        List of unique occupation strings.
    model:
        SentenceTransformer model instance.
    device:
        Device identifier to run computations on ('cuda' or 'cpu').
    similarity_threshold:
        Cosine-similarity threshold used to derive the clustering distance
        threshold; higher values produce smaller, tighter clusters.
    batch_size:
        Batch size used when encoding occupations.

    Returns
    -------
    clusters:
        Mapping from cluster id to list of occupation strings.
    cluster_embs:
        Mapping from cluster id to the mean (normalized-member) embedding
        tensor of the cluster.
    """
    if not occupations:
        return {}, {}

    embeddings = model.encode(
        occupations,
        batch_size=batch_size,
        convert_to_tensor=True,
        device=device,
        show_progress_bar=False,
    )
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

    # AgglomerativeClustering rejects inputs with fewer than two samples, so a
    # lone occupation is returned as its own cluster instead of raising
    if len(occupations) == 1:
        return {0: [occupations[0]]}, {0: embeddings[0]}

    # cosine similarity for L2-normalized embeddings is the dot product;
    # clamp so float round-off cannot produce slightly negative distances
    dist_matrix = (1.0 - embeddings @ embeddings.T).clamp(min=0.0)

    clustering = AgglomerativeClustering(
        n_clusters=None,
        metric="precomputed",
        linkage="average",
        distance_threshold=1 - similarity_threshold,
    )
    clustering.fit(dist_matrix.cpu().numpy())

    # group row indices by cluster label, in order of first appearance
    member_idxs: Dict[int, List[int]] = {}
    for idx, label in enumerate(clustering.labels_):
        member_idxs.setdefault(int(label), []).append(idx)

    clusters: Dict[int, List[str]] = {
        label: [occupations[i] for i in idxs]
        for label, idxs in member_idxs.items()
    }
    cluster_embs: Dict[int, torch.Tensor] = {
        label: torch.mean(embeddings[idxs], dim=0)
        for label, idxs in member_idxs.items()
    }
    return clusters, cluster_embs


def categorize_clusters(
    clusters: Dict[int, List[str]],
    cluster_embs: Dict[int, torch.Tensor],
    cat_embs: torch.Tensor,
    categories: List[str]
) -> Dict[int, str]:
    """
    Label each cluster with the category whose embedding scores highest.

    Every cluster's mean embedding is L2-normalized and scored against the
    category embeddings by dot product. This equals cosine similarity when
    ``cat_embs`` are L2-normalized. The highest-scoring category (argmax)
    becomes the cluster's label.

    Parameters
    ----------
    clusters:
        Mapping from cluster id to list of occupations; only used to
        short-circuit when empty.
    cluster_embs:
        Mapping from cluster id to the mean embedding tensor of each cluster.
    cat_embs:
        Precomputed, normalized embeddings for category labels.
    categories:
        Category label strings, in the same order as the rows of cat_embs.

    Returns
    -------
    dict
        Mapping from each cluster id in ``cluster_embs`` to its best-matching
        category label.
    """
    if not clusters:
        return {}

    ids = list(cluster_embs)
    centroids = torch.nn.functional.normalize(
        torch.stack([cluster_embs[cid] for cid in ids]), p=2, dim=1
    )
    winners = (centroids @ cat_embs.T).argmax(dim=1).tolist()
    return {cid: categories[winner] for cid, winner in zip(ids, winners)}


def process_chunk(
    df: pd.DataFrame,
    model: SentenceTransformer,
    cat_embs: torch.Tensor,
    categories: List[str],
    device: str,
    batch_size: int,
    similarity_threshold: float
) -> pd.DataFrame:
    """
    Cluster a chunk's notability features and assign high-level categories.

    Expects ``notability_features`` to already hold lists of cleaned strings,
    as produced by ``parse_and_trim_notability``. The steps are:

    1. All unique, non-empty features in the chunk are clustered together.
    2. Each cluster is mapped to its closest category.
    3. Each row gets, in ``field_of_human_activity``, the comma-joined,
       de-duplicated categories of its features.

    Keyword overrides (``apply_keyword_override``) take precedence over the
    embedding-based assignment. Rows with no usable features get "unknown".

    Parameters
    ----------
    df:
        Input dataframe chunk.
    model:
        SentenceTransformer model instance.
    cat_embs:
        Precomputed embeddings of category labels.
    categories:
        List of category label strings.
    device:
        Device identifier to run computations on ('cuda' or 'cpu').
    batch_size:
        Batch size used when encoding.
    similarity_threshold:
        Threshold used when clustering features.

    Returns
    -------
    pd.DataFrame
        Copy of the chunk without rows whose ``notability_features`` is
        missing, with an added ``field_of_human_activity`` column. If no rows
        remain, the empty dataframe is returned without that column.
    """
    df = df.dropna(subset=["notability_features"]).copy()
    if df.empty:
        return df

    # collect unique, non-empty features in first-seen order. Iterating a
    # set() would give an order that depends on per-process string hash
    # randomization, so the clustering input order (and possibly its result)
    # could vary between runs even with a fixed seed.
    unique_feats = list(
        dict.fromkeys(
            feat
            for feats in df["notability_features"]
            for feat in feats
            if feat
        )
    )

    if not unique_feats:
        df["field_of_human_activity"] = "unknown"
        return df

    # cluster the unique features
    clusters, cluster_embs = deduplicate_occupations_with_similarity(
        unique_feats,
        model,
        device,
        similarity_threshold=similarity_threshold,
        batch_size=batch_size,
    )

    # assign categories to clusters
    cluster_to_category = categorize_clusters(
        clusters,
        cluster_embs,
        cat_embs,
        categories,
    )

    # build mapping from individual feature to its category;
    # keyword overrides take precedence over embedding-based assignment
    feat_to_category: Dict[str, str] = {}
    for cid, feats in clusters.items():
        for feat in feats:
            override = apply_keyword_override(feat)
            feat_to_category[feat] = (
                override if override is not None else cluster_to_category[cid]
            )

    def map_feats_to_category(feat_list: List[str]) -> str:
        """Map a list of features to a comma-separated list of categories."""
        if not feat_list:
            return "unknown"
        # features never clustered (e.g. empty strings) fall back to "unknown"
        cats = [feat_to_category.get(f, "unknown") for f in feat_list]
        # remove duplicates while preserving order
        return ",".join(dict.fromkeys(cats))

    df["field_of_human_activity"] = df["notability_features"].apply(
        map_feats_to_category
    )

    # if you wish to store embeddings for the original feature lists, you can
    # re-enable and adapt the following block:
    #
    # joined_feats = df["notability_features"].apply(
    #     lambda feats: " | ".join(feats)
    # ).tolist()
    # feat_embeddings = model.encode(
    #     joined_feats,
    #     batch_size=batch_size,
    #     convert_to_tensor=True,
    #     device=device,
    # )
    # df["notability_feature_embeddings"] = feat_embeddings.cpu().numpy().tolist()

    return df


def _chunk_to_table(
    df: pd.DataFrame,
    schema: Optional[pa.Schema],
    input_schema: pa.Schema,
) -> pa.Table:
    """
    Convert a processed chunk to an Arrow table with a stable schema.

    ParquetWriter rejects tables whose schema differs from the one the file
    was opened with. Pandas-to-Arrow inference can drift between chunks, for
    example an all-null column becomes type ``null``, or only-empty lists
    become ``list<null>``. To avoid this:

    - Later chunks (``schema`` given) are coerced to that schema.
    - For the first chunk, null-typed columns take their type from the input
      file where available, and ``list<null>`` columns become
      ``list<string>``.
    """
    if schema is not None:
        return pa.Table.from_pandas(df, schema=schema, preserve_index=False)

    table = pa.Table.from_pandas(df, preserve_index=False)
    fields = []
    for field in table.schema:
        if pa.types.is_null(field.type) and field.name in input_schema.names:
            # every value missing in this chunk: reuse the input column type
            field = field.with_type(input_schema.field(field.name).type)
        elif pa.types.is_list(field.type) and pa.types.is_null(field.type.value_type):
            # only empty lists in this chunk: assume lists of strings
            field = field.with_type(pa.list_(pa.string()))
        fields.append(field)

    fixed_schema = pa.schema(fields, metadata=table.schema.metadata)
    if fixed_schema.equals(table.schema):
        return table
    return pa.Table.from_pandas(df, schema=fixed_schema, preserve_index=False)


def run_notability_pipeline(
    input_file_path: Union[str, Path],
    output_file_path: Union[str, Path],
    raw_categories: Optional[List[str]] = None,
    model_name: str = DEFAULT_MODEL_NAME,
    chunksize: int = DEFAULT_CHUNKSIZE,
    similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
    batch_size: int = DEFAULT_BATCH_SIZE,
    seed: int = DEFAULT_SEED,
    row_limit: Optional[int] = None,
) -> None:
    """
    Run the notability feature pipeline on a parquet file.

    This function is designed to be called from a Jupyter notebook. Only the
    input and output paths are required; all other arguments have defaults and
    can be overridden as needed.

    Errors are logged rather than raised. If processing fails part-way, the
    output file contains only the chunks written before the failure.

    Parameters
    ----------
    input_file_path:
        Path to the input parquet file with a 'notability_features' column.
    output_file_path:
        Path where the processed parquet file will be written.
    raw_categories:
        Optional list of category labels. If None (or empty),
        DEFAULT_RAW_CATEGORIES is used. Labels are lowercased before use.
    model_name:
        HuggingFace model name for the SentenceTransformer.
    chunksize:
        Number of rows per chunk when streaming the parquet file.
    similarity_threshold:
        Cosine similarity threshold used for clustering notability features.
    batch_size:
        Batch size for encoding text with the sentence embedding model.
    seed:
        Random seed used to make clustering and embedding behavior reproducible.
    row_limit:
        Optional maximum number of processed rows to write. If None, all rows
        are processed. Rows dropped during processing do not count towards
        the limit.

    Returns
    -------
    None
        The function writes the processed data to the given output path.
    """
    setup_nltk()
    set_seed(seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    input_path = Path(input_file_path)
    output_path = Path(output_file_path)

    categories = [c.lower() for c in (raw_categories or DEFAULT_RAW_CATEGORIES)]

    # load the sentence transformer model and compute category embeddings once
    model = SentenceTransformer(model_name, device=device)
    cat_embs = model.encode(
        categories,
        batch_size=len(categories),
        convert_to_tensor=True,
        device=device,
        show_progress_bar=False,
    )
    cat_embs = torch.nn.functional.normalize(cat_embs, p=2, dim=1)

    writer: Optional[pq.ParquetWriter] = None
    output_schema: Optional[pa.Schema] = None
    total_rows_processed = 0

    try:
        parquet_file = pq.ParquetFile(input_path)

        for batch in parquet_file.iter_batches(batch_size=chunksize):
            # stop before reading more data once the row budget is used up
            if row_limit is not None and total_rows_processed >= row_limit:
                break

            df_chunk = batch.to_pandas()
            if df_chunk.empty:
                continue

            df_chunk = parse_and_trim_notability(df_chunk)
            processed_chunk = process_chunk(
                df_chunk,
                model=model,
                cat_embs=cat_embs,
                categories=categories,
                device=device,
                batch_size=batch_size,
                similarity_threshold=similarity_threshold,
            )

            if row_limit is not None:
                # never write more than row_limit rows in total; remaining is
                # positive here thanks to the check at the top of the loop
                remaining = row_limit - total_rows_processed
                processed_chunk = processed_chunk.iloc[:remaining]

            if processed_chunk.empty:
                continue

            # write the processed chunk with a schema that stays stable
            table = _chunk_to_table(
                processed_chunk, output_schema, parquet_file.schema_arrow
            )
            if writer is None:
                output_schema = table.schema
                writer = pq.ParquetWriter(
                    output_path,
                    output_schema,
                    compression="gzip",
                )
            writer.write_table(table)

            total_rows_processed += len(processed_chunk)

    except FileNotFoundError:
        logging.error("input file not found: %s", input_path)
    except Exception:
        logging.exception("an unexpected error occurred during processing")
    finally:
        if writer is not None:
            writer.close()
